Visualizing the Phillips Curve
The Phillips Curve was initially discovered as a statistical relationship between unemployment and inflation. The original version used historical UK data.

Our goal here is to visually inspect the Phillips curve using recent data, for several countries.
In the process we will learn to:
- import dataframes, inspect them, merge them, and clean the resulting data
- use matplotlib to create graphs
- bonus: experiment with other plotting libraries
Importing the Data
We start by loading the dbnomics library, which gives access to all the data we need. It is already installed on the nuvolos server.
The following code imports data from dbnomics for a few countries.
```python
import dbnomics

table_1 = dbnomics.fetch_series([
    "OECD/DP_LIVE/FRA.CPI.TOT.AGRWTH.Q",
    "OECD/DP_LIVE/GBR.CPI.TOT.AGRWTH.Q",
    "OECD/DP_LIVE/USA.CPI.TOT.AGRWTH.Q",
    "OECD/DP_LIVE/DEU.CPI.TOT.AGRWTH.Q",
])

table_2 = dbnomics.fetch_series([
    "OECD/MEI/DEU.LRUNTTTT.STSA.Q",
    "OECD/MEI/FRA.LRUNTTTT.STSA.Q",
    "OECD/MEI/USA.LRUNTTTT.STSA.Q",
    "OECD/MEI/GBR.LRUNTTTT.STSA.Q",
])
```
Describe concisely the data that has been imported (periodicity, type of measure, …). You can either check the dbnomics website or look at the databases.
The data comes from dbnomics. The provider is the OECD. The database is the “Data Live dataset” (DP_LIVE) for inflation, and the “Main Economic Indicators Publication” (MEI) for unemployment.
Data is for several countries (Germany, France, United States, United Kingdom).
- inflation: Consumer Price Index for all goods and services (total), as an annual growth rate, measured every quarter
- unemployment: Labour Force Survey quarterly rates, workers aged 15 and over
Show the first rows of each database. Make a list of all columns.
```python
table_1.head(2)
```
| | @frequency | provider_code | dataset_code | dataset_name | series_code | series_name | original_period | period | original_value | value | LOCATION | INDICATOR | SUBJECT | MEASURE | FREQUENCY | Country | Indicator | Subject | Measure | Frequency |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | quarterly | OECD | DP_LIVE | OECD Data Live dataset | FRA.CPI.TOT.AGRWTH.Q | France – Inflation (CPI) – Total – Annual grow... | 1956-Q1 | 1956-01-01 | 1.746324 | 1.746324 | FRA | CPI | TOT | AGRWTH | Q | France | Inflation (CPI) | Total | Annual growth rate (%) | Quarterly |
| 1 | quarterly | OECD | DP_LIVE | OECD Data Live dataset | FRA.CPI.TOT.AGRWTH.Q | France – Inflation (CPI) – Total – Annual grow... | 1956-Q2 | 1956-04-01 | 1.838658 | 1.838658 | FRA | CPI | TOT | AGRWTH | Q | France | Inflation (CPI) | Total | Annual growth rate (%) | Quarterly |
```python
table_1.columns
```
```
Index(['@frequency', 'provider_code', 'dataset_code', 'dataset_name',
       'series_code', 'series_name', 'original_period', 'period',
       'original_value', 'value', 'LOCATION', 'INDICATOR', 'SUBJECT',
       'MEASURE', 'FREQUENCY', 'Country', 'Indicator', 'Subject', 'Measure',
       'Frequency'],
      dtype='object')
```
```python
table_2.columns
```
```
Index(['@frequency', 'provider_code', 'dataset_code', 'dataset_name',
       'series_code', 'series_name', 'original_period', 'period',
       'original_value', 'value', 'LOCATION', 'SUBJECT', 'MEASURE',
       'FREQUENCY', 'Country', 'Subject', 'Measure', 'Frequency'],
      dtype='object')
```
Compute standard statistics for all variables
```python
table_1.describe()
```
| | period | value |
|---|---|---|
| count | 1084 | 1083.000000 |
| mean | 1989-09-30 18:46:29.667896704 | 3.891054 |
| min | 1956-01-01 00:00:00 | -1.623360 |
| 25% | 1972-10-01 00:00:00 | 1.680341 |
| 50% | 1989-10-01 00:00:00 | 2.724924 |
| 75% | 2006-10-01 00:00:00 | 5.002851 |
| max | 2023-07-01 00:00:00 | 26.565810 |
| std | NaN | 3.591647 |
```python
table_2.describe()
```
| | period | original_value | value |
|---|---|---|---|
| count | 817 | 817.000000 | 817.000000 |
| mean | 1994-11-08 05:59:33.561811584 | 6.098632 | 6.098632 |
| min | 1955-01-01 00:00:00 | 0.373774 | 0.373774 |
| 25% | 1979-07-01 00:00:00 | 4.366667 | 4.366667 |
| 50% | 1996-07-01 00:00:00 | 5.833333 | 5.833333 |
| 75% | 2011-01-01 00:00:00 | 7.996948 | 7.996948 |
| max | 2023-10-01 00:00:00 | 13.066667 | 13.066667 |
| std | NaN | 2.550835 | 2.550835 |
Average inflation over the period, across all countries, is 3.89%. Average unemployment over the period, across all countries, is 6.09%.
Compute averages and standard deviations for unemployment and inflation, per country.
```python
# option 1: use pandas boolean selection
# we want to extract a sub-dataframe for each country
ind = table_1['Country'] == 'France'
table_1_fr = table_1[ind]

# what are the unique values taken by the column 'Country'?
# set(table_1['Country'])  # pure python
table_1['Country'].unique()
```
```
array(['France', 'United Kingdom', 'United States', 'Germany'], dtype=object)
```
```python
table_1_fra = table_1.query("Country=='France'")
table_1_gbr = table_1.query("Country=='United Kingdom'")
table_1_deu = table_1.query("Country=='Germany'")
table_1_usa = table_1.query("Country=='United States'")

# same thing with a loop and a dictionary
d = dict()
for country in ["France", "United Kingdom", "Germany", "United States"]:
    d[country] = table_1.query(f"Country=='{country}'")

# list comprehension
[table_1.query(f"Country=='{country}'") for country in table_1['Country'].unique()];

# dictionary comprehension
d = {country: table_1.query(f"Country=='{country}'") for country in table_1['Country'].unique()}

for k, v in d.items():
    print(f"{k}, mean: {v['value'].mean()}")
```
```
France, mean: 4.218004559985239
United Kingdom, mean: 5.003996036162361
United States, mean: 3.678550917822878
Germany, mean: 2.659118566666667
```
```python
# option 2: use groupby
table_1.groupby("Country")['value'].agg('mean')
```
```
Country
France            4.218005
Germany           2.659119
United Kingdom    5.003996
United States     3.678551
Name: value, dtype: float64
```
```python
# standard deviation
table_1.groupby("Country")['value'].agg('std')
```
```
Country
France            3.853190
Germany           1.866131
United Kingdom    4.768593
United States     2.779505
Name: value, dtype: float64
```
```python
table_1.groupby("Country")['value'].agg(['mean', 'std'])
```
| | mean | std |
|---|---|---|
| Country | | |
| France | 4.218005 | 3.853190 |
| Germany | 2.659119 | 1.866131 |
| United Kingdom | 5.003996 | 4.768593 |
| United States | 3.678551 | 2.779505 |
```python
table_1.groupby("Country")['value'].agg('describe')
```
| | count | mean | std | min | 25% | 50% | 75% | max |
|---|---|---|---|---|---|---|---|---|
| Country | | | | | | | | |
| France | 271.0 | 4.218005 | 3.853190 | -0.423247 | 1.660842 | 2.670692 | 5.765484 | 18.565260 |
| Germany | 270.0 | 2.659119 | 1.866131 | -0.922850 | 1.421377 | 2.107098 | 3.441268 | 8.580543 |
| United Kingdom | 271.0 | 5.003996 | 4.768593 | -0.453172 | 2.000000 | 3.244983 | 6.133533 | 26.565810 |
| United States | 271.0 | 3.678551 | 2.779505 | -1.623360 | 1.784109 | 3.023983 | 4.523969 | 14.505600 |
The following command merges the two databases together. Explain the role of the argument `on`. What happened to the column names?
```python
# we have two dataframes with similar columns
table_1.columns
```
```
Index(['@frequency', 'provider_code', 'dataset_code', 'dataset_name',
       'series_code', 'series_name', 'original_period', 'period',
       'original_value', 'value', 'LOCATION', 'INDICATOR', 'SUBJECT',
       'MEASURE', 'FREQUENCY', 'Country', 'Indicator', 'Subject', 'Measure',
       'Frequency'],
      dtype='object')
```
```python
table_2.columns
```
```
Index(['@frequency', 'provider_code', 'dataset_code', 'dataset_name',
       'series_code', 'series_name', 'original_period', 'period',
       'original_value', 'value', 'LOCATION', 'SUBJECT', 'MEASURE',
       'FREQUENCY', 'Country', 'Subject', 'Measure', 'Frequency'],
      dtype='object')
```
```python
table = table_1.merge(table_2, on=["period", 'Country'])
table.columns
```
```
Index(['@frequency_x', 'provider_code_x', 'dataset_code_x', 'dataset_name_x',
       'series_code_x', 'series_name_x', 'original_period_x', 'period',
       'original_value_x', 'value_x', 'LOCATION_x', 'INDICATOR', 'SUBJECT_x',
       'MEASURE_x', 'FREQUENCY_x', 'Country', 'Indicator', 'Subject_x',
       'Measure_x', 'Frequency_x', '@frequency_y', 'provider_code_y',
       'dataset_code_y', 'dataset_name_y', 'series_code_y', 'series_name_y',
       'original_period_y', 'original_value_y', 'value_y', 'LOCATION_y',
       'SUBJECT_y', 'MEASURE_y', 'FREQUENCY_y', 'Subject_y', 'Measure_y',
       'Frequency_y'],
      dtype='object')
```
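The behaviour of `on` and of the automatic `_x`/`_y` suffixes can be seen on a minimal sketch with two hypothetical toy frames (not the OECD data): rows are matched on the listed key columns, and any column present in both frames but not listed in `on` is kept twice, with suffixes.

```python
import pandas as pd

# two tiny frames sharing the key columns and a clashing 'value' column
cpi = pd.DataFrame({'period': ['2020-Q1', '2020-Q2'],
                    'Country': ['France', 'France'],
                    'value': [1.1, 0.3]})
unemp = pd.DataFrame({'period': ['2020-Q1', '2020-Q2'],
                      'Country': ['France', 'France'],
                      'value': [8.1, 8.4]})

# rows are matched on the (period, Country) pairs; the 'value' column,
# present in both frames but not in `on`, is duplicated with _x/_y suffixes
merged = cpi.merge(unemp, on=['period', 'Country'])
print(merged.columns.tolist())  # ['period', 'Country', 'value_x', 'value_y']
```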
We rename the columns for the sake of clarity and normalize everything to lower case.
```python
table = table.rename(columns={
    'period': 'date',  # because it sounds more natural
    'Country': 'country',
    'value_x': 'inflation',
    'value_y': 'unemployment'
})
```
On the merged table, compute at once all the statistics computed before (use groupby and agg).
```python
table.groupby('country')[['unemployment', 'inflation']].agg('mean')
```
| | unemployment | inflation |
|---|---|---|
| country | | |
| France | 8.680560 | 1.664349 |
| Germany | 4.989272 | 2.730136 |
| United Kingdom | 6.705114 | 5.404707 |
| United States | 5.880812 | 3.678551 |
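To really get all the earlier statistics at once, `agg` accepts a list of functions and returns one column per (variable, statistic) pair. A sketch on a hypothetical toy frame:

```python
import pandas as pd

df_toy = pd.DataFrame({
    'country': ['France', 'France', 'Germany', 'Germany'],
    'inflation': [2.0, 4.0, 1.0, 3.0],
    'unemployment': [8.0, 9.0, 5.0, 6.0],
})

# one groupby call returns means and standard deviations for both variables,
# indexed by a two-level column header
stats = df_toy.groupby('country')[['unemployment', 'inflation']].agg(['mean', 'std'])
print(stats)
```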
```python
# the resulting dataframe still has horrible column names
table.head()
```
| | @frequency_x | provider_code_x | dataset_code_x | dataset_name_x | series_code_x | series_name_x | original_period_x | date | original_value_x | inflation | ... | original_period_y | original_value_y | unemployment | LOCATION_y | SUBJECT_y | MEASURE_y | FREQUENCY_y | Subject_y | Measure_y | Frequency_y |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | quarterly | OECD | DP_LIVE | OECD Data Live dataset | FRA.CPI.TOT.AGRWTH.Q | France – Inflation (CPI) – Total – Annual grow... | 2003-Q1 | 2003-01-01 | 2.366263 | 2.366263 | ... | 2003-Q1 | 7.922234 | 7.922234 | FRA | LRUNTTTT | STSA | Q | Labour Force Survey - quarterly rates > Unempl... | Level, rate or national currency, s.a. | Quarterly |
| 1 | quarterly | OECD | DP_LIVE | OECD Data Live dataset | FRA.CPI.TOT.AGRWTH.Q | France – Inflation (CPI) – Total – Annual grow... | 2003-Q2 | 2003-04-01 | 1.912854 | 1.912854 | ... | 2003-Q2 | 8.089598 | 8.089598 | FRA | LRUNTTTT | STSA | Q | Labour Force Survey - quarterly rates > Unempl... | Level, rate or national currency, s.a. | Quarterly |
| 2 | quarterly | OECD | DP_LIVE | OECD Data Live dataset | FRA.CPI.TOT.AGRWTH.Q | France – Inflation (CPI) – Total – Annual grow... | 2003-Q3 | 2003-07-01 | 1.93227 | 1.932270 | ... | 2003-Q3 | 8.036090 | 8.036090 | FRA | LRUNTTTT | STSA | Q | Labour Force Survey - quarterly rates > Unempl... | Level, rate or national currency, s.a. | Quarterly |
| 3 | quarterly | OECD | DP_LIVE | OECD Data Live dataset | FRA.CPI.TOT.AGRWTH.Q | France – Inflation (CPI) – Total – Annual grow... | 2003-Q4 | 2003-10-01 | 2.184437 | 2.184437 | ... | 2003-Q4 | 8.349410 | 8.349410 | FRA | LRUNTTTT | STSA | Q | Labour Force Survey - quarterly rates > Unempl... | Level, rate or national currency, s.a. | Quarterly |
| 4 | quarterly | OECD | DP_LIVE | OECD Data Live dataset | FRA.CPI.TOT.AGRWTH.Q | France – Inflation (CPI) – Total – Annual grow... | 2004-Q1 | 2004-01-01 | 1.800087 | 1.800087 | ... | 2004-Q1 | 8.518631 | 8.518631 | FRA | LRUNTTTT | STSA | Q | Labour Force Survey - quarterly rates > Unempl... | Level, rate or national currency, s.a. | Quarterly |
5 rows × 36 columns
Before we proceed further, we should tidy the dataframe by keeping only what we need:
- keep only the columns date, country, inflation and unemployment
- drop all NA values
- make a copy of the result
```python
# there are some NAs in the dataframe
sum(table['inflation'].isna())
```
```
1
```
```python
table[
    ['inflation', 'unemployment']  # list of columns to select
]

# same selection on one line
table[['inflation', 'unemployment']]

df = table[['date', 'country', 'inflation', 'unemployment']].dropna()
df = df.copy()
# note: copy() avoids keeping references to the original dataframe
```
What is the maximum available interval for each country? How would you proceed to keep only those dates where all data are available? In the following we keep the resulting “cylindric” database.
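One possible approach to both questions, sketched on a hypothetical toy frame with the same column names as our df: compute each country's date range with groupby, then keep only the dates observed for every country.

```python
import pandas as pd

df_toy = pd.DataFrame({
    'date': pd.to_datetime(['2020-01-01', '2020-04-01', '2020-01-01']),
    'country': ['France', 'France', 'Germany'],
    'inflation': [1.0, 1.2, 0.5],
    'unemployment': [8.0, 8.1, 5.0],
})

# maximum available interval per country
print(df_toy.groupby('country')['date'].agg(['min', 'max']))

# keep only dates observed for every country (the "cylindric" panel)
n_countries = df_toy['country'].nunique()
counts = df_toy.groupby('date')['country'].nunique()
common_dates = counts[counts == n_countries].index
balanced = df_toy[df_toy['date'].isin(common_dates)]
print(balanced)
```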
Our DataFrame is now ready for further analysis!
Plotting using matplotlib
Our goal now is to plot inflation against unemployment and see whether a pattern emerges. We will first work on France.
```python
from matplotlib import pyplot as plt
```
Create a dataframe df_fr which contains only the data for France.
```python
df_fr = df.query("country=='France'")
```
The following command creates a line plot of inflation against unemployment. Can you transform it into a scatter plot?
```python
plt.plot(df_fr['unemployment'], df_fr['inflation'])  # missing 'o'

# create a scatter plot
plt.plot(df_fr['unemployment'], df_fr['inflation'], 'o')
```
Expand the above command to make the plot nicer (labels, title, grid, …).
```python
# create a scatter plot
plt.plot(df_fr['unemployment'], df_fr['inflation'], 'o')
plt.title("Phillips Curve")
plt.xlabel("Unemployment")
plt.ylabel("Inflation")
plt.grid()
```
Visualizing the regression
The following piece of code regresses inflation on unemployment.
```python
from statsmodels.formula import api as sm

model = sm.ols(formula='inflation ~ unemployment', data=df_fr)
result = model.fit()
```
```python
result.summary()
```
| Dep. Variable: | inflation | R-squared: | 0.435 |
|---|---|---|---|
| Model: | OLS | Adj. R-squared: | 0.428 |
| Method: | Least Squares | F-statistic: | 62.46 |
| Date: | Wed, 07 Feb 2024 | Prob (F-statistic): | 1.17e-11 |
| Time: | 10:39:59 | Log-Likelihood: | -120.37 |
| No. Observations: | 83 | AIC: | 244.7 |
| Df Residuals: | 81 | BIC: | 249.6 |
| Df Model: | 1 | | |
| Covariance Type: | nonrobust | | |

| | coef | std err | t | P>\|t\| | [0.025 | 0.975] |
|---|---|---|---|---|---|---|
| Intercept | 9.7054 | 1.024 | 9.479 | 0.000 | 7.668 | 11.743 |
| unemployment | -0.9263 | 0.117 | -7.903 | 0.000 | -1.160 | -0.693 |

| Omnibus: | 9.014 | Durbin-Watson: | 0.241 |
|---|---|---|---|
| Prob(Omnibus): | 0.011 | Jarque-Bera (JB): | 10.906 |
| Skew: | 0.520 | Prob(JB): | 0.00428 |
| Kurtosis: | 4.439 | Cond. No. | 79.0 |
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
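The intercept and slope reported in the table can also be read programmatically from `result.params`, a Series indexed by term name. A sketch on synthetic data with a known linear relationship (hypothetical numbers, not our OECD sample):

```python
import pandas as pd
from statsmodels.formula import api as sm

# synthetic data built so that inflation = 10 - 2 * unemployment exactly
toy = pd.DataFrame({'unemployment': [4.0, 6.0, 8.0, 10.0]})
toy['inflation'] = 10 - 2 * toy['unemployment']

toy_result = sm.ols(formula='inflation ~ unemployment', data=toy).fit()
# params recovers the intercept (10) and slope (-2)
print(toy_result.params['Intercept'], toy_result.params['unemployment'])
```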
We can use the resulting model to “predict” inflation from unemployment.
```python
result.predict(df_fr['unemployment'])
```
```
0     2.366810
1     2.211777
2     2.261342
3     1.971104
4     1.814349
        ...
78    3.064908
79    3.055976
80    3.141252
81    2.926298
82    2.888369
Length: 83, dtype: float64
```
Store the result in df_fr as a new column reg_inflation.
```python
# full index specification with .loc[:, column] avoids an extra error message
df_fr.loc[:, 'reg_inflation'] = result.predict(df_fr['unemployment'])
```
```
/tmp/ipykernel_63599/2161117277.py:2: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_fr.loc[:,'reg_inflation'] = result.predict(df_fr['unemployment'])
```
```python
df_fr.head()
```
| | date | country | inflation | unemployment | reg_inflation |
|---|---|---|---|---|---|
| 0 | 2003-01-01 | France | 2.366263 | 7.922234 | 2.366810 |
| 1 | 2003-04-01 | France | 1.912854 | 8.089598 | 2.211777 |
| 2 | 2003-07-01 | France | 1.932270 | 8.036090 | 2.261342 |
| 3 | 2003-10-01 | France | 2.184437 | 8.349410 | 1.971104 |
| 4 | 2004-01-01 | France | 1.800087 | 8.518631 | 1.814349 |
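The SettingWithCopyWarning appears because df_fr is a slice of df. Two common ways to avoid it, shown on a hypothetical toy frame: take an explicit copy before writing, or build the new column with `assign`, which returns a new frame.

```python
import pandas as pd

toy = pd.DataFrame({'unemployment': [8.0, 9.0, 5.0], 'inflation': [2.0, 3.0, 1.0]})

# explicit copy: later writes touch an independent frame, so no warning
high = toy.query('unemployment > 7').copy()
high.loc[:, 'reg_inflation'] = 0.0

# assign() builds the new column and returns a new frame in one step
high2 = toy.query('unemployment > 7').assign(reg_inflation=0.0)
print(high2.columns.tolist())
```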
Add the regression line to the scatter plot.
```python
# scatter plot with the regression line
plt.plot(df_fr['unemployment'], df_fr['inflation'], 'o')
plt.plot(df_fr['unemployment'], df_fr['reg_inflation'])
plt.title("Phillips Curve")
plt.xlabel("Unemployment")
plt.ylabel("Inflation")
plt.grid()
```
Now we would like to compare all countries. Can you find a way to represent the data for all of them (all on one graph, using subplots…) ?
```python
df.head()
```
| | date | country | inflation | unemployment |
|---|---|---|---|---|
| 0 | 2003-01-01 | France | 2.366263 | 7.922234 |
| 1 | 2003-04-01 | France | 1.912854 | 8.089598 |
| 2 | 2003-07-01 | France | 1.932270 | 8.036090 |
| 3 | 2003-10-01 | France | 2.184437 | 8.349410 |
| 4 | 2004-01-01 | France | 1.800087 | 8.518631 |
```python
countries_list = df['country'].unique()
countries_list
```
```
array(['France', 'United Kingdom', 'United States', 'Germany'], dtype=object)
```
```python
from statsmodels.formula import api as sm

for country in countries_list:
    print(f"country=='{country}'")
    df_country = df.query(f"country=='{country}'")
    model = sm.ols(formula='inflation ~ unemployment', data=df_country)
    result = model.fit()
    df_country.loc[:, 'reg_inflation'] = result.predict(df_country['unemployment'])
    # scatter plot plus regression line for each country
    plt.plot(df_country['unemployment'], df_country['inflation'], '.', label=country, alpha=0.5)
    plt.plot(df_country['unemployment'], df_country['reg_inflation'])

# style the figure once, outside the loop (plt.grid() toggles on each call)
plt.title("Phillips Curve")
plt.xlabel("Unemployment")
plt.ylabel("Inflation")
plt.legend()
plt.grid()
```
```
country=='France'
country=='United Kingdom'
country=='United States'
country=='Germany'
```
```
/tmp/ipykernel_63599/2982871465.py:9: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_country.loc[:,'reg_inflation'] = result.predict(df_country['unemployment'])
```
(the same warning is emitted once per country)

```python
# maybe nicer with subplots: plt.subplot(2, 2, k) selects panel k of a 2x2 grid
# i = 0
# for country in countries_list:
#     i = i + 1
# enumerate gives the index and the element at once
for i, country in enumerate(countries_list):
    print(f"country=='{country}'")
    df_country = df.query(f"country=='{country}'")
    model = sm.ols(formula='inflation ~ unemployment', data=df_country)
    result = model.fit()
    df_country.loc[:, 'reg_inflation'] = result.predict(df_country['unemployment'])
    # one panel per country
    plt.subplot(2, 2, i + 1)
    plt.plot(df_country['unemployment'], df_country['inflation'], '.', label=country, alpha=0.5)
    plt.plot(df_country['unemployment'], df_country['reg_inflation'])
    plt.title("Phillips Curve")
    plt.xlabel("Unemployment")
    plt.ylabel("Inflation")
    plt.legend()
    plt.grid()

plt.tight_layout()
```
```
country=='France'
country=='United Kingdom'
country=='United States'
country=='Germany'
```
```
/tmp/ipykernel_63599/566381394.py:16: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_country.loc[:,'reg_inflation'] = result.predict(df_country['unemployment'])
```
(the same warning is emitted once per country)
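A variant of the loop above uses the object-oriented interface `plt.subplots`, which returns the figure and an array of axes and avoids repeated `plt.subplot` calls. The numbers below are synthetic placeholders, not the OECD data:

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch also runs in scripts
import matplotlib.pyplot as plt

# synthetic stand-ins for the four countries' (unemployment, inflation) series
panels = {
    'France': ([7, 8, 9], [3, 2, 1]),
    'Germany': ([4, 5, 6], [2, 2, 1]),
    'United Kingdom': ([5, 6, 7], [4, 3, 2]),
    'United States': ([4, 6, 8], [3, 2, 1]),
}

fig, axes = plt.subplots(2, 2, figsize=(8, 6))
for ax, (name, (unemp, infl)) in zip(axes.flat, panels.items()):
    ax.plot(unemp, infl, 'o')
    ax.set_title(name)
    ax.set_xlabel("Unemployment")
    ax.set_ylabel("Inflation")
    ax.grid(True)
fig.tight_layout()
```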

Any comment on these results?
Bonus: Visualizing data using altair
Altair is a visualization library (based on Vega-lite) which offers a different syntax to make plots.
It is well adapted to the exploration phase, as it can operate on a full database (without splitting it like we did for matplotlib). It also provides some data transformation tools like regressions, and ways to add some interactivity.
```python
import altair as alt

chart = alt.Chart(df).mark_point()
chart
```
The following command makes a basic plot from the dataframe df, which contains all the countries. Can you enhance it by providing a title and encoding information to distinguish the various countries (for instance colors)?
```python
chart = alt.Chart(df).mark_point().encode(
    x='unemployment',
    y='inflation',
    # add something here
)
chart
```
The following graph plots a regression line, but fitted to all countries at once it is rather meaningless. Can you restrict the data to France only?
```python
# modify the following code
chart = alt.Chart(df).mark_point().encode(
    x='unemployment',
    y='inflation',
)
chart + chart.transform_regression('unemployment', 'inflation').mark_line()
```
One way to visualize data consists in adding some interactivity. Add a title, then click on the legend.
```python
# run first, then modify the following code
multi = alt.selection_multi(fields=["country"])
legend = alt.Chart(df).mark_point().encode(
    y=alt.Y('country:N', axis=alt.Axis(orient='right')),
    color=alt.condition(multi, 'country:N', alt.value('lightgray'), legend=None)
).add_selection(multi)
chart_2 = alt.Chart(df).mark_point().encode(
    x='unemployment',
    y='inflation',
    color=alt.condition(multi, 'country:N', alt.value('lightgray')),
    # find a way to separate on the graph data from France and US
)
chart_2 | legend
```
```
AltairDeprecationWarning: 'selection_multi' is deprecated. Use 'selection_point'
AltairDeprecationWarning: 'add_selection' is deprecated. Use 'add_params' instead.
```
Bonus question: in the following graph you can select an interval in the left panel to select some subsample. Can you add the regression line(s) corresponding to the selected data to the last graph?
```python
brush = alt.selection_interval(encodings=['x'])
historical_chart_1 = alt.Chart(df).mark_line().encode(
    x='date',
    y='unemployment',
    color='country'
).add_selection(
    brush
)
historical_chart_2 = alt.Chart(df).mark_line().encode(
    x='date',
    y='inflation',
    color='country'
)
chart = alt.Chart(df).mark_point().encode(
    x='unemployment',
    y='inflation',
    # find a way to separate on the graph data from France and US
    color=alt.condition(brush, 'country:N', alt.value('lightgray'))
)
alt.hconcat(historical_chart_1, historical_chart_2, chart)
```
```
AltairDeprecationWarning: 'add_selection' is deprecated. Use 'add_params' instead.
```
Bonus 2: Plotly Express
Another popular option is the plotly library for nice-looking interactive plots. Combined with dash or shiny, it can be used to build very powerful interactive interfaces.
```python
import plotly.express as px

fig = px.scatter(df, x='unemployment', y='inflation', color='country', title="Phillips Curves")
fig
```
```
FutureWarning: When grouping with a length-1 list-like, you will need to pass a length-1 tuple to get_group in a future version of pandas. Pass `(name,)` instead of `name` to silence this warning.
```
Improve the graph above in any way you like.